Skip to content

Conversation

@nastya236
Copy link
Contributor

@nastya236 nastya236 commented Nov 18, 2025

This PR adds a new operation mx.qqmm. The current structure is probably neither optimal nor final.

General comment

  1. For inference we want to support: qqmm(quantized weights, bf16 activations).
  2. For training (vjp) we unfortunately still need bf16 weights for two reasons:
    • We currently do not have 2D scaling for nvfp4, so we need to transpose and quantize again along a different dimension.
    • For mxfp8, the recommended recipe is to quantize with 1D blocks and keep two views of the weights (normal and transposed).

Therefore, mx.qqmm takes bf16 activations x, quantized weights w_q and their scales, and optionally bf16 weights plus group_size, mode, and bits.

In the current implementation, it is the user’s responsibility to ensure that group_size, bits, and mode match those used to quantize w_q. This is probably not ideal, and we may want to improve this in the future.

Very important details

  1. scales are repacked on every call for both weights and activations. In the future, we probably want to:

    • Avoid repacking weight scales for inference.
    • Fuse quantization and repacking, and directly pack into swizzled layout in fp_quantize.
  2. Batched qqmm is currently not supported; inputs must be 2D. For now it is implemented this way because:

    • CUBLASLT_BATCH_MODE_STRIDED is not supported for scales.
    • CUBLASLT_BATCH_MODE_POINTER_ARRAY is not supported for arrays with block scaling.

We almost certainly want to add batching in the future, but for simplicity batch_count = 1 for now.

  1. qqmm is always executed in TN layout (transpose = True).
    There are several reasons for this, but mainly we always quantize along the reduction dimension, which currently ends up being the last dimension.. I am happy to change this if you think that it is useful to support all layouts for mxfp8 for example. Also, only on B200 only TN layout is supported for nvfp4 and mxfp4.

Notes

  1. There are some changes to cublas_gemm.cpp: I grouped all common cuBLAS-related functions into a separate helper class in cublas_utils.cpp.
  2. mxfp8 qqmm behaves slightly differently from nvfp4: sometimes, for <<1% of the output elements, the result differs from the dequantized reference by exactly 1 ULP in bf16 (see python/tests/test_quantized.py, line 1027). I do not think this is a bug because:
  • For nvfp4 the output matches exactly for every tested shape.
  • The difference is not structured: there is no clear pattern, and the indices of the affected elements change with the seed.
  • The mismatch is always exactly 1 ULP.

Therefore, I attribute this to differences in accumulation on tensor cores or other numerical details we do not control.

What this PR lacks [these] because I first want to make sure the rest of the API looks reasonable

  1. addmm -- basically c is always nullptr
  2. nn.QQLinear
  3. nn.Linear.to_qqlinear - or similar method to cast to nn.QQLinear (naming is questionable)

Examples are in python/tests/test_quantized.py.
Happy to iterate and change anything here!

@nastya236 nastya236 marked this pull request as draft November 18, 2025 18:43
@nastya236 nastya236 changed the title qqmm [WIP] qqmm Nov 18, 2025
@nastya236 nastya236 marked this pull request as ready for review November 29, 2025 20:20
@nastya236 nastya236 changed the title [WIP] qqmm qqmm Nov 29, 2025
bool transpose_;
};

class DualQuantizedMatmul : public UnaryPrimitive {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

A bit of a nit but I think it makes sense to rename this to QuantizedQuantizedMatmul or QQMatmul to better match the name of the op. Dual is also kind of an overloaded term.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, I agree. I think QQMatmul is better, because then the primitive name and the op name are aligned.

bool is_equivalent(const Primitive& other) const override;
std::vector<Shape> output_shapes(const std::vector<array>& inputs) override;
auto state() const {
return std::make_tuple(group_size_, bits_, mode_);
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

transpose_ should be part of the state here.

Copy link
Contributor Author

@nastya236 nastya236 Dec 2, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah, this is a bit unclear and probably should be changed.. transpose is not a member variable, qqmm is always executed in TN layout (transpose = True). I did it this way because, at the moment, quantization always produces a row-major tensor with the last dimension packed, and TN is the only layout supported for mxfp4 and nvfp4 on B200.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I see it below in the list under private:. Maybe it should be deleted?


ds = mx.grad(gmm)(s, x, wq)

def test_qqmm(self):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

These tests will should only be run for now if mx.cuda.is_available().

And in fact I'm not sure what the behavior is on older hardware and CUDA toolkits. Do you know what the minimum requirements there are?

std::optional<int> bits_ /* = std::nullopt */,
const std::string& mode /* = "nvfp4" */,
StreamOrDevice s /* = {} */) {
// currently only simetric quantization is supported for qqmm
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
// currently only simetric quantization is supported for qqmm
// currently only symmetric quantization is supported for qqmm

Comment on lines +4334 to +4338
if (qmode == QuantizationMode::Affine) {
std::ostringstream msg;
msg << "[qqmm] Affine quantization is not supported for qqmm.";
throw std::invalid_argument(msg.str());
}
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It looks like this was already checked above?

// https://docs.nvidia.com/cutlass/4.2.1/media/docs/cpp/blackwell_functionality.html
// because w_q should always be quantized along the reduction dimension
// and we quantize so that the last dim is packed, we assume that the last dim
// always the reduction dim so the firat argument in cubals column major is
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
// always the reduction dim so the firat argument in cubals column major is
// is always the reduction dim so the first argument in cublas column major is

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This comment feels like it belongs in the cublas implementation rather than here.

auto [w_inner_dims, w_outer_dims] =
extract_qqmm_dims("qqmm", x, w_q, scales_w, w, group_size, bits);

// we don't backprope through qunatized w and scales
Copy link
Member

@awni awni Dec 2, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
// we don't backprope through qunatized w and scales
// we don't backprop through quantized w and scales

Comment on lines +4367 to +4368
auto dtype = bfloat16;
// out dtype can be only bf16 for now
Copy link
Member

@awni awni Dec 2, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why this limitation? It looks like the op can output bf16, fp16 or fp32: https://docs.nvidia.com/cuda/cublas/#id103.

The API should infer the output type from x.

Comment on lines +166 to +167
validate_quantized_input(
tag, w_q, scales_w, "weight matrix", "scales_w", group_size, bits);
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think you can remove the strings here since x is not quantized. The original error message prior to the diff here makes sense.

Comment on lines +1413 to +1416
array x, // input activations
array w_q, // quantized weights
array w_scales,
std::optional<array> w = std::nullopt, // optional bf16 weights for vjp
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't really love this API where sometimes it takes a w as input and sometimes not. I wonder if it makes sense to change it to something like:

Suggested change
array x, // input activations
array w_q, // quantized weights
array w_scales,
std::optional<array> w = std::nullopt, // optional bf16 weights for vjp
array x, // input activations
array w, // possibly quantized weights
std::optional<array> scales, // scales for w, if not provided `w` must be unquantized

So then it will quantize on the fly if w is not quantized and otherwise it will just use w as is.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

And in order to take a vjp, w has to be provided unquantized.

bits_,
qmode,
s); // (K, N_packed), scales
vjps.push_back(qqmm(
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

A minor problem here is that this function is only once differentiable. I think changing the API as suggested above migth fix that. You always quantize the inputs on the fly when you want gradients.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants